from utils import *
import sys
from math import inf
import time
from tabulate import tabulate
from sklearn.cluster import KMeans

import os

# np.save below does not create missing directories; make sure the output
# directory exists before any trial runs.
os.makedirs('./simulations', exist_ok=True)


def _cluster_counts(labels, n_clusters):
    """Count how many labels fall into each cluster.

    Returns a float array of length ``n_clusters + 1``: entry ``c`` is the
    number of labels equal to ``c`` and the trailing entry is always 0, matching
    the trailing 0 of the full-data weight vector ``w`` built in the trial loop.
    """
    return np.bincount(labels, minlength=n_clusters + 1).astype(float)


def _stratified_sample(strata, samp):
    """Draw up to ``samp`` points without replacement from every stratum.

    Each point drawn from a stratum of size ``n`` gets weight ``n / n_drawn``,
    so the weights of a stratum sum to its size. Strata with fewer than
    ``samp`` points contribute all their points instead of raising (as
    ``np.random.choice(..., replace=False)`` would). A trailing 0 is appended
    to the weights, matching the layout of the other weight vectors here.
    Uses the global ``np.random`` state.

    Returns ``(points, weights)`` as numpy arrays; ``len(points)`` is the
    number of sampled scenarios.
    """
    points = []
    weights = []
    for stratum in strata:
        n_draw = min(samp, len(stratum))
        if n_draw == 0:
            continue
        idx = np.random.choice(len(stratum), n_draw, replace=False)
        points.extend(stratum[j] for j in idx)
        weights.extend([len(stratum) / n_draw] * n_draw)
    weights.append(0)
    return np.array(points), np.array(weights)


for trial in range(20):
    # --- Instance parameters (re-set at the start of every trial) ---
    m = 6
    K = 10
    M = 1.0
    a = 3.0
    b = 3.0
    low = 5.0
    high = 8.0  # alternative setting used previously: 10.0
    mu0 = generate_mu(m, low, high, a, b)
    mu1 = generate_mu(m, low, high, a, b)
    mu2 = generate_mu(m, low, high, a, b)
    mu3 = generate_mu(m, low, high, a, b)
    log = {'mu0' : mu0, 'mu1' : mu1, 'mu2' : mu2, 'mu3' : mu3}
    N_tot = 500  # alternative setting used previously: 50
    var = 3.0
    # Generate attacker data
    theta_1, theta_2 = generate_theta_normal(mu0, var * np.eye(m), mu1, var * np.eye(m), mu2, var * np.eye(m), mu3, var * np.eye(m), N_tot)
    full_theta = np.concatenate([theta_1, theta_2], axis=2)
    print("Mu0 : ", mu0)
    print("Mu1 : ", mu1)
    print("Mu2 : ", mu2)
    print("Mu3 : ", mu3)

    # --- Reference solution on all N_tot scenarios, each with weight 1 ---
    xi = 1e6
    # `b_vec` (not `b`) so the Beta parameter `b` above is not shadowed.
    A, b_vec, C, d, tL, tU = compute_params(m, K, N_tot, numerator, denominator, full_theta)
    w = np.ones(N_tot + 1)
    w[-1] = 0
    tot_z_dro = FCP_DRO(m, K, N_tot, N_tot, M, tL, tU, A, b_vec, C, d, w, xi)
    opt_tot = utility_robust(full_theta, numerator, denominator, m, tot_z_dro, N_tot, xi, w)

    # Flattened scenarios for K-means, plus the same data in the (N_tot, m, 4)
    # layout that compute_params receives. Both are loop-invariant per trial.
    tot_theta = np.array(full_theta).reshape(N_tot, -1)
    data_points = tot_theta.reshape(N_tot, m, 4)

    all_z = []
    all_v = []
    all_times = []  # stays empty: timing of FCP_DRO is currently disabled
    all_losses = []
    N_lower = 8
    N_upper = 24
    for N in range(N_lower, N_upper + 1, N_lower):
        # (1) Cluster approach: N K-means centroids, weighted by cluster size.
        kmeans = KMeans(n_clusters=N).fit(tot_theta)
        all_losses.append(kmeans.inertia_)
        cluster_cent = kmeans.cluster_centers_.reshape(N, m, 4)
        s = _cluster_counts(kmeans.predict(tot_theta), N)
        A, b_vec, C, d, tL, tU = compute_params(m, K, N, numerator, denominator, cluster_cent)
        z_dro = FCP_DRO(m, K, N, N_tot, M, tL, tU, A, b_vec, C, d, s, xi)
        opt = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
        all_v.append(opt)
        all_z.append(z_dro)
        # Same relative-gap (%) metric as the sampling runs below.
        print("N : ", N, " diff : ", 100 * (opt_tot - opt) / opt_tot)

        # (2) Sampling approach: N // samp strata, samp draws per stratum,
        # repeated 10 times for each samp in 1, 2, 4, ..., N_lower.
        samp = 1
        while samp <= N_lower:
            num_clusters = N // samp
            kmeans = KMeans(n_clusters=num_clusters).fit(tot_theta)
            labels = kmeans.predict(tot_theta)
            strata = [data_points[labels == c] for c in range(num_clusters)]
            for rep in range(10):
                sampled_points, ls = _stratified_sample(strata, samp)
                # Use the real number of drawn scenarios; it equals N only
                # when every stratum had at least `samp` points and samp | N.
                n_points = len(sampled_points)
                A, b_vec, C, d, tL, tU = compute_params(m, K, n_points, numerator, denominator, sampled_points)
                z_dro = FCP_DRO(m, K, n_points, N_tot, M, tL, tU, A, b_vec, C, d, ls, xi)
                opt = utility_robust(full_theta, numerator, denominator, m, z_dro, N_tot, xi, w)
                all_v.append(opt)
                print("N : ", N, " samp : ", samp, " diff : ", 100 * (opt_tot - opt) / opt_tot)
                all_z.append(z_dro)
            samp *= 2

    results_location = './simulations/SSG_cluster_vs_sampling_{}_trial_{}.npy'.format(m, trial)
    log['all_z'] = all_z
    log['all_times'] = all_times
    log['tot_z'] = tot_z_dro
    log['tot_v'] = opt_tot
    log['all_v'] = all_v
    log['all_losses'] = all_losses
    # The dict is pickled by np.save; load with np.load(..., allow_pickle=True).
    np.save(results_location, log)
